{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Distributed Data Processing using SageMaker Processing and DJL Spark Container\n",
    "\n",
     "This example notebook demonstrates how to use Amazon SageMaker Processing with the DJL Spark Docker image to run distributed deep learning inference on large datasets. The DJL Spark Docker image is a pre-built image that includes the Deep Java Library (DJL) and the other dependencies needed to run distributed data processing jobs on Amazon SageMaker. DJL is an open-source, Java-based deep learning library designed to be easy to use and compatible with existing deep learning ecosystems.\n",
    "\n",
    "By using these two services together, you can easily run distributed deep learning inference on large datasets in a scalable and cost-effective manner."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Contents\n",
    "\n",
    "1. [Setup](#Setup)\n",
    "1. [Push the Container to ECR](#Push-the-Container-to-ECR)\n",
    "1. [Prepare Processing Script](#Prepare-Processing-Script)\n",
    "1. [Run the SageMaker Processing Job](#Run-the-SageMaker-Processing-Job)\n",
    "1. [Monitor and Analyze Your Job](#Monitor-and-Analyze-Your-Job)\n",
    "1. [Validate Data Processing Results](#Validate-Data-Processing-Results)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Install the SageMaker Python SDK\n",
    "\n",
    "This notebook requires the SageMaker Python SDK."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
     "%pip install sagemaker"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setup account and role\n",
    "\n",
     "Next, you'll define the account, region, and role that will be used to run the SageMaker Processing job."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sagemaker\n",
    "from time import gmtime, strftime\n",
    "\n",
    "sagemaker_session = sagemaker.Session()\n",
    "account_id = sagemaker_session.account_id()\n",
    "region = sagemaker_session.boto_region_name\n",
    "role = sagemaker.get_execution_role()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Push the Container to ECR\n",
    "\n",
     "The following section pulls the DJL Spark Docker image from Docker Hub and pushes it to your Amazon ECR repository."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "docker_registry=\"deepjavalibrary\"\n",
    "repository_name=\"djl-spark\"\n",
    "tag=\"0.27.0-cpu\"\n",
    "ecr_registry=\"{}.dkr.ecr.{}.amazonaws.com\".format(account_id, region)\n",
    "\n",
    "# Pull the DJL Spark image\n",
    "!docker pull $docker_registry/$repository_name:$tag\n",
    "\n",
    "# Create ECR repository and push the image to your ECR\n",
     "!aws ecr get-login-password --region $region | docker login --username AWS --password-stdin $ecr_registry\n",
    "repository_query = !(aws ecr describe-repositories --repository-names $repository_name)\n",
    "if repository_query[0] == '':\n",
    "    !(aws ecr create-repository --repository-name $repository_name)\n",
    "!docker tag $docker_registry/$repository_name:$tag $ecr_registry/$repository_name:$tag\n",
    "!docker push $ecr_registry/$repository_name:$tag"
   ]
  },
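  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick illustration of the naming convention used above, the private ECR image URI is assembled from the account ID, region, repository name, and tag. The account ID and region below are placeholders; in this notebook the real values come from the variables defined earlier:\n",
    "\n",
    "```python\n",
    "# Illustrative values; in the notebook these come from the SageMaker session.\n",
    "account_id = \"123456789012\"\n",
    "region = \"us-east-1\"\n",
    "repository_name = \"djl-spark\"\n",
    "tag = \"0.27.0-cpu\"\n",
    "\n",
    "# Private ECR registries follow the <account>.dkr.ecr.<region>.amazonaws.com pattern.\n",
    "ecr_registry = \"{}.dkr.ecr.{}.amazonaws.com\".format(account_id, region)\n",
    "image_uri = \"{}/{}:{}\".format(ecr_registry, repository_name, tag)\n",
    "print(image_uri)  # 123456789012.dkr.ecr.us-east-1.amazonaws.com/djl-spark:0.27.0-cpu\n",
    "```"
   ]
  },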
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare Processing Script\n",
    "\n",
    "The source for the processing script is in the cell below. The cell uses the `%%writefile` directive to save this file locally.\n",
    "\n",
     "This script performs an image classification task using a pretrained ResNet-18 model. Note that this example does its own pre-processing with NumPy and passes the input to the model in .npz binary format, runs the inference, then does its own post-processing and reads the result back from .npz binary format."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %%writefile does not create directories, so make sure ./code exists first\n",
    "!mkdir -p ./code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%writefile ./code/process.py\n",
    "import argparse\n",
    "import os\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "\n",
    "from typing import Iterator\n",
    "\n",
    "from pyspark.sql.session import SparkSession\n",
    "from pyspark.sql.functions import udf\n",
    "from pyspark.sql.types import StructType, StructField, BinaryType, FloatType, ArrayType\n",
    "\n",
    "from djl_spark.task.binary import BinaryPredictor\n",
    "from djl_spark.util import np_util\n",
    "\n",
    "\n",
    "def transform(img_data):\n",
    "    img = Image.frombytes(mode='RGB', data=img_data.data, size=[img_data.width, img_data.height])\n",
    "\n",
    "    # Resize\n",
    "    img = img.resize([224, 224], resample=Image.Resampling.BICUBIC)\n",
    "\n",
    "    # BGR to RGB\n",
    "    arr = np.flip(np.asarray(img), axis=2)\n",
    "\n",
    "    # ToTensor operation\n",
    "    arr = np.expand_dims(arr, axis=0)\n",
    "    arr = np.divide(arr, 255.0)\n",
    "    arr = arr.transpose((0, 3, 1, 2))\n",
    "    arr = np.squeeze(arr, axis=0)\n",
    "    arr = np.float32(arr)\n",
    "    return np_util.to_npz([arr])\n",
    "\n",
    "\n",
    "def transform_udf(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:\n",
    "    for batch in iterator:\n",
    "        batch[\"transformed_image\"] = batch.apply(transform, axis=1)\n",
    "        yield batch\n",
    "\n",
    "\n",
    "def softmax(data):\n",
    "    arr = np_util.from_npz(data)[0]\n",
    "    prob = np.exp(arr) / np.sum(np.exp(arr), axis=0)\n",
    "    return prob.tolist()\n",
    "\n",
    "\n",
    "def main():\n",
    "    parser = argparse.ArgumentParser(description=\"app inputs and outputs\")\n",
    "    parser.add_argument(\"--s3_input_bucket\", type=str, help=\"s3 input bucket\")\n",
    "    parser.add_argument(\"--s3_input_key_prefix\", type=str, help=\"s3 input key prefix\")\n",
    "    parser.add_argument(\"--s3_output_bucket\", type=str, help=\"s3 output bucket\")\n",
    "    parser.add_argument(\"--s3_output_key_prefix\", type=str, help=\"s3 output key prefix\")\n",
    "    args = parser.parse_args()\n",
    "\n",
    "    spark = SparkSession.builder.appName(\"sm-spark-djl-binary\").getOrCreate()\n",
    "\n",
    "    # Input\n",
    "    df = spark.read.format(\"image\").option(\"dropInvalid\", True).load(\"s3://\" + os.path.join(args.s3_input_bucket, args.s3_input_key_prefix))\n",
    "\n",
    "    df = df.select(\"image.*\").filter(\"nChannels=3\")  # The model expects RGB images\n",
    "\n",
    "    # Pre-processing\n",
    "    schema = StructType(df.select(\"*\").schema.fields + [\n",
    "        StructField(\"transformed_image\", BinaryType(), True)\n",
    "    ])\n",
    "    transformed_df = df.mapInPandas(transform_udf, schema)\n",
    "\n",
    "    # Inference\n",
    "    predictor = BinaryPredictor(input_col=\"transformed_image\",\n",
    "                                output_col=\"prediction\",\n",
    "                                engine=\"PyTorch\",\n",
    "                                model_url=\"djl://ai.djl.pytorch/resnet\",\n",
    "                                batchifier=\"stack\")\n",
     "    output_df = predictor.predict(transformed_df)\n",
     "\n",
     "    # Post-processing\n",
     "    softmax_udf = udf(softmax, ArrayType(FloatType()))\n",
     "    output_df = output_df.withColumn(\"probabilities\", softmax_udf(output_df[\"prediction\"]))\n",
     "\n",
     "    output_df = output_df.select(\"origin\", \"probabilities\")\n",
     "    output_df.write.mode(\"overwrite\").parquet(\"s3://\" + os.path.join(args.s3_output_bucket, args.s3_output_key_prefix))\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    main()"
   ]
  },
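  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the pre- and post-processing concrete, here is a small standalone sketch of the same steps outside Spark, using only NumPy and Pillow. The solid 300x200 image is a hypothetical stand-in for a Coco sample, and `np.savez` stands in for DJL's `np_util.to_npz`:\n",
    "\n",
    "```python\n",
    "import io\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "\n",
    "# Hypothetical 300x200 RGB image standing in for a Coco sample.\n",
    "img = Image.new(\"RGB\", (300, 200), color=(255, 0, 0))\n",
    "\n",
    "# Resize to the model's expected 224x224 input.\n",
    "img = img.resize((224, 224), resample=Image.Resampling.BICUBIC)\n",
    "\n",
    "# Flip the channel axis, as transform() does for Spark's BGR pixel layout.\n",
    "arr = np.flip(np.asarray(img), axis=2)\n",
    "\n",
    "# ToTensor: scale to [0, 1] and reorder to channels-first (C, H, W).\n",
    "arr = np.float32(arr.transpose((2, 0, 1)) / 255.0)\n",
    "print(arr.shape)  # (3, 224, 224)\n",
    "\n",
    "# Serialize to the .npz binary format, as np_util.to_npz does in the script.\n",
    "buf = io.BytesIO()\n",
    "np.savez(buf, arr)\n",
    "print(buf.getvalue()[:4])  # .npz files are zip archives: b'PK\\x03\\x04'\n",
    "\n",
    "# Post-processing: softmax over raw logits, mirroring softmax() in the script.\n",
    "logits = np.array([1.0, 2.0, 3.0])\n",
    "prob = np.exp(logits) / np.sum(np.exp(logits), axis=0)\n",
    "print(round(float(prob.sum()), 6))  # 1.0\n",
    "```"
   ]
  },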
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run the SageMaker Processing Job\n",
    "\n",
    "Next, we'll create a `PySparkProcessor` with the following parameters:\n",
    "\n",
    "* `base_job_name`: Set the prefix for processing name to \"sm-spark-djl\".\n",
     "* `image_uri`: Set the URI of the Docker image to the image you pushed to ECR above.\n",
    "* `role`: Set the AWS IAM role to use for the processing job.\n",
    "* `instance_count`: Set the number of instances to run the processing job to 2.\n",
    "* `instance_type`: Set the type of EC2 instance to use for processing to \"ml.m5.xlarge\".\n",
    "\n",
    "We also set a Spark configuration:\n",
    "\n",
    "* `spark.executor.memory`: Set the amount of memory to use per executor process to 2g.\n",
    "* `spark.executor.cores`: Set the number of cores to use on each executor to 2.\n",
    "\n",
    "Then, the code calls the `run` method of the processor to start the processing job. It passes in the following parameters:\n",
    "\n",
    "* `submit_app`: The path of the processing script to submit to Spark.\n",
    "* `arguments`: List of string arguments to be passed to the processing job, including the input and output location. The input dataset we use is 300 images from the [Coco](https://cocodataset.org/#download) dataset.\n",
     "* `configuration`: The Spark configuration defined above.\n",
     "* `spark_event_logs_s3_uri`: The S3 path where Spark application events will be published.\n",
     "* `logs`: Whether to show the logs produced by the job; set to False here.\n",
     "* `wait`: Whether to wait until the job completes; set to True here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker.spark.processing import PySparkProcessor\n",
    "\n",
    "input_bucket = \"alpha-djl-demos\"\n",
    "input_prefix = \"dataset/cv/coco\"\n",
    "\n",
    "timestamp_prefix = strftime(\"%Y-%m-%d-%H-%M-%S\", gmtime())\n",
    "prefix = \"sagemaker/spark-processing-djl-demo/{}\".format(timestamp_prefix)\n",
    "output_bucket = sagemaker_session.default_bucket()\n",
    "output_prefix = f\"{prefix}/output\"\n",
    "\n",
    "image_uri = \"{}/{}:{}\".format(ecr_registry, repository_name, tag)\n",
    "\n",
    "# Run the processing job\n",
    "spark_processor = PySparkProcessor(\n",
    "    base_job_name=\"sm-spark-djl\",\n",
    "    image_uri=image_uri,\n",
    "    role=role,\n",
    "    instance_count=2,\n",
    "    instance_type=\"ml.m5.xlarge\"\n",
    ")\n",
    "\n",
    "configuration = [\n",
    "    {\n",
    "        \"Classification\": \"spark-defaults\",\n",
    "        \"Properties\": {\"spark.executor.memory\": \"2g\", \"spark.executor.cores\": \"2\"}\n",
    "    }\n",
    "]\n",
    "\n",
    "spark_processor.run(\n",
    "    submit_app=\"./code/process.py\",\n",
    "    arguments=[\n",
    "        \"--s3_input_bucket\", input_bucket,\n",
    "        \"--s3_input_key_prefix\", input_prefix,\n",
    "        \"--s3_output_bucket\", output_bucket,\n",
    "        \"--s3_output_key_prefix\", output_prefix,\n",
    "    ],\n",
    "    configuration=configuration,\n",
    "    spark_event_logs_s3_uri=\"s3://{}/{}/spark_event_logs\".format(output_bucket, prefix),\n",
    "    logs=False, # Do not show the logs produced by the job\n",
    "    wait=True # Wait until the job completes\n",
    ")"
   ]
  },
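  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `arguments` list is delivered to `process.py` as ordinary command-line flags. Below is a minimal sketch of how the script's `argparse` setup consumes such a list; the output bucket and prefix values are illustrative:\n",
    "\n",
    "```python\n",
    "import argparse\n",
    "\n",
    "parser = argparse.ArgumentParser(description=\"app inputs and outputs\")\n",
    "parser.add_argument(\"--s3_input_bucket\", type=str)\n",
    "parser.add_argument(\"--s3_input_key_prefix\", type=str)\n",
    "parser.add_argument(\"--s3_output_bucket\", type=str)\n",
    "parser.add_argument(\"--s3_output_key_prefix\", type=str)\n",
    "\n",
    "# The same flag/value pairs that spark_processor.run() passes via `arguments`.\n",
    "args = parser.parse_args([\n",
    "    \"--s3_input_bucket\", \"alpha-djl-demos\",\n",
    "    \"--s3_input_key_prefix\", \"dataset/cv/coco\",\n",
    "    \"--s3_output_bucket\", \"my-output-bucket\",\n",
    "    \"--s3_output_key_prefix\", \"sagemaker/spark-processing-djl-demo/output\",\n",
    "])\n",
    "print(args.s3_input_bucket)  # alpha-djl-demos\n",
    "print(args.s3_output_bucket)  # my-output-bucket\n",
    "```"
   ]
  },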
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Monitor and Analyze Your Job\n",
    "\n",
    "Next, by calling `start_history_server()`, you can start the Spark history server and access the Spark UI to view details about the Spark application. This is useful for debugging and troubleshooting, as well as for monitoring the performance and behavior of your Spark processing job."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "spark_processor.start_history_server()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After viewing the Spark UI, you can terminate the history server before proceeding."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "spark_processor.terminate_history_server()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Validate Data Processing Results\n",
    "\n",
    "Next, validate the output of the Spark job by ensuring that the output URI contains the Spark `_SUCCESS` file along with the output files."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Output files in s3://{}/{}/\".format(output_bucket, output_prefix))\n",
    "!aws s3 ls s3://$output_bucket/$output_prefix/"
   ]
  }
 ],
 "metadata": {
  "instance_type": "ml.t3.medium",
  "kernelspec": {
   "display_name": "conda_python3",
   "language": "python",
   "name": "conda_python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
